import scipy.io
import torch
import h5py
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torchvision
from torchvision import transforms
from torchvision.datasets import CocoCaptions
from tqdm import tqdm
from scipy.io import loadmat
from torchvision.transforms import ToTensor, Normalize
import open_clip


# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Load a ViT-B-32 CLIP model from a local checkpoint. The preprocessing transforms
# returned by open_clip are discarded; `dtransforms` below is used instead.
model, _, _ = open_clip.create_model_and_transforms('ViT-B-32',
                                                    pretrained='./data/model/clip/mineclip/vitB/coco_5C_3_100.pt')

# Resize to 224x224, convert to a tensor and normalize.
# NOTE(review): these mean/std values are the ImageNet statistics, not the CLIP
# normalization constants open_clip's own preprocessing uses — confirm intended.
dtransforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
model = model.to(device)
# The model is only used for inference here.
model.eval()

# NOTE(review): `encode_image` is a method on open_clip models, so this attribute
# path presumably raises AttributeError; open_clip ViT blocks are believed to live
# under `model.visual.transformer.resblocks`. Also note that IntermediateLayerGetter
# feeds its input through the children in order, so transformer blocks would receive
# raw images rather than token embeddings — verify the probing approach.
layers = model.encode_image.encoder.layers.to(device)

coco_root = './data/datasets/coco'
# `coco_root` has no trailing slash, so the separator must be inserted explicitly;
# plain concatenation would produce './data/datasets/cocoannotations/...'.
ann_file = coco_root + '/annotations/captions_train2014.json'
img_dir = coco_root + '/train2014/'
dataset = CocoCaptions(img_dir, ann_file, dtransforms)

datacanary = loadmat("./data/mat/clipmem/coco/canarylist.mat")
# loadmat returns MATLAB variables as (at least) 2-D arrays, so `.tolist()` alone
# yields nested lists that Subset would forward to CocoCaptions as list indices.
# Flatten to a flat list of plain Python ints instead.
# NOTE(review): confirm the stored indices are 0-based (MATLAB indexing is 1-based).
canarylist = np.asarray(datacanary['c_o_list']).ravel().astype(np.int64).tolist()
canaryset = torch.utils.data.Subset(dataset, canarylist)
canarydataloader = DataLoader(canaryset, batch_size=1, shuffle=True, num_workers=4)

# Random crop-and-resize back to 224x224, applied to each image batch per view.
augmentation_set = transforms.RandomResizedCrop(size=224)
# Expose the output of child '11' of `layers` under the key 'feat1'.
new_m = torchvision.models._utils.IntermediateLayerGetter(layers, {'11': 'feat1'})
def augmented_unit_means(extractor, augment, img, n_aug=10):
    """Return per-unit activations for one image, averaged over random augmentations.

    Each of the ``n_aug`` views is produced by ``augment(img)`` and passed through
    ``extractor``, which must return a mapping of layer name -> tensor. Every
    returned tensor is reshaped to (256, 4) and averaged over its last axis,
    giving a 256-element vector per view and layer. The result is the mean of
    those vectors over all views and returned layers.

    NOTE(review): the hard-coded (256, 4) reshape only works when each returned
    tensor has exactly 1024 elements — confirm against the layer being probed.
    """
    per_view = []
    for _ in range(n_aug):
        out = extractor(augment(img))
        for v in out.values():
            per_view.append(np.mean(v.reshape(256, 4).cpu().detach().numpy(), axis=1))
    return np.mean(np.array(per_view), axis=0)


def unit_selectivity(finalout):
    """Return ``(max - median) / (max + median)`` for each unit (column).

    ``finalout`` is indexed (sample, unit). The median for a unit is taken over
    all samples except the one with that unit's largest activation (the sorted
    column minus its last entry). Units where ``max + median == 0`` produce NaN
    or +/-inf (numpy emits a RuntimeWarning).
    """
    maxout = np.max(finalout, axis=0)
    medianout = np.median(np.sort(finalout, axis=0)[0:-1], axis=0)
    return (maxout - medianout) / (maxout + medianout)


if __name__ == '__main__':
    final1 = []
    # Inference only: avoid building autograd graphs for the 10 forward passes
    # performed per image.
    with torch.no_grad():
        for img, _label in tqdm(canarydataloader):
            img = img.to(device)
            final1.append(augmented_unit_means(new_m, augmentation_set, img, n_aug=10))

    # One row per canary image, one column per unit.
    finalout = np.array(final1)
    selectivity = unit_selectivity(finalout)
    scipy.io.savemat('./data/clip/coco/unitmem_coco_1cap_image_11.mat', {'selectivity': selectivity})
